from genericpath import exists
import os
import sys
import boto3
import random
import warnings
import importlib
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.backends.cudnn as cudnn
from torch.nn.parallel import DistributedDataParallel as DDP

from data import DataPrefetcher
from utils import resume, get_state_dict, save_checkpoint, save_checkpoint_oss
from utils.log import setup_logger, setup_writer
from utils.dist import configure_nccl, init_process_group, synchronize
from utils.default_argparse import default_argument_parser

from exp.trainer import Trainer


def main_worker(gpu, ngpus_per_node, args):
    """Run the whole pre-training loop in one process bound to one GPU.

    The function resolves the process rank, (optionally) joins the
    distributed process group, builds the trainer, moves the model to the
    GPU, optionally resumes from a checkpoint and then trains for
    ``args.total_epochs`` epochs, saving a checkpoint after every epoch on
    rank 0.

    Args:
        gpu: Index of the CUDA device used by this process on its node.
        ngpus_per_node: Number of GPUs on this node. When it is greater than
            one, the process group is initialised and the model is wrapped
            in ``DistributedDataParallel``.
        args: Parsed command-line namespace. ``args.rank`` and ``args.gpu``
            are overwritten here.
    """
    configure_nccl()

    # ------------ set environment variables for distributed training ------------------------------------- #
    if args.rank is None:
        args.rank = int(os.getenv('RLAUNCH_REPLICA', '0'))

    args.gpu = gpu
    if ngpus_per_node > 1:
        # Turn the per-node rank into a global, per-process rank.
        args.rank = args.rank * ngpus_per_node + gpu

        # initialize process group
        init_process_group(args)

    # get the trainer
    trainer = Trainer()
    updated_config = trainer.update(args.exp_options)

    # make dir for experiment output (created by rank 0 only; the other
    # ranks wait at the barrier before they use it)
    file_name = os.path.join(args.output_dir, 'imagenet' + str(args.num_classes), args.experiment_name)
    if args.rank == 0:
        os.makedirs(file_name, exist_ok=True)
    synchronize()

    # setup the logger and writer
    logger = setup_logger(file_name, distributed_rank=args.rank, filename='train_log.txt', mode='a')
    writer = setup_writer(file_name, distributed_rank=args.rank)

    # setup model, dataloader and optimizer
    trainer.build_dataloader(args)
    trainer.build_model(args)
    trainer.build_optimizer(args)

    if args.rank == 0:
        logger.info('args: {}'.format(args))
        # Every trainer attribute whose name does not contain 'model' is
        # reported as a hyper-parameter.
        hyper_param = [
            str(key) + '=' + str(value)
            for key, value in trainer.__dict__.items()
            if 'model' not in key
        ]
        logger.info('Hyper-parameters: {}'.format(', '.join(hyper_param)))

        if updated_config:
            logger.opt(ansi=True).info("List of override configs:\n<blue>{}</blue>\n".format(updated_config))

        logger.info('Model: ')
        logger.info(str(trainer.model))

    # put the model onto gpu
    torch.cuda.set_device(gpu)
    trainer.model.cuda(gpu)
    if ngpus_per_node > 1:
        trainer.model = DDP(trainer.model, device_ids=[gpu])

    cudnn.benchmark = True

    # resume
    if args.resume:
        resume(args, trainer)

    # ------------------------ start training ------------------------------------------------------------ #

    if args.rank == 0:
        logger.info('Start training from iteration {},'
                    ' and the total training iterations is {}'.format(trainer.ITERS_PER_EPOCH * args.start_epoch + 1,
                                                                      trainer.total_iters))

    # oss: only rank 0 ever uploads, and only when uploads are enabled, so
    # the client is created just for that case. It is still created before
    # training starts so that a misconfiguration surfaces immediately.
    bucket_name = 'hjq-oss'
    oss_saved_models_dir = 'ssl_saved_models'
    s3_client = None
    if args.rank == 0 and not args.no_oss_saved:
        if not args.ws:
            host = 'http://oss.i.brainpp.cn'
        else:
            host = 'http://oss-internal.hh-b.brainpp.cn'
        s3_client = boto3.client('s3', endpoint_url=host)

    if args.debug:
        from IPython import embed
        embed()

    # Tell the sampler which epoch we start from *before* the first iterator
    # is created; otherwise a run resumed at start_epoch > 0 would begin with
    # the sampler's default epoch ordering.
    if args.world_size > 1:
        trainer.data_loader.sampler.set_epoch(args.start_epoch)
    trainer.prefetcher = DataPrefetcher(trainer.data_loader, args.single_aug)

    for epoch in range(args.start_epoch, args.total_epochs):
        # set epoch
        trainer.epoch = epoch

        # The prefetcher is exhausted once its next_input is None: reshuffle
        # for this epoch and start a fresh pass over the data loader.
        if trainer.prefetcher.next_input is None:
            if args.world_size > 1:
                trainer.data_loader.sampler.set_epoch(epoch)
            trainer.prefetcher = DataPrefetcher(trainer.data_loader, args.single_aug)

        trainer.train(args, logger, writer)

        # save models
        if args.rank == 0:
            state_dict = get_state_dict(trainer)
            save_checkpoint(state_dict, False, file_name, 'last_epoch')
            # save in oss (every 10th epoch)
            if (epoch + 1) % 10 == 0 and not args.no_oss_saved:
                save_checkpoint_oss(s3_client, bucket_name, oss_saved_models_dir, state_dict, False, file_name,
                                    args.experiment_name + '-{}epoch'.format(epoch + 1))

    if args.rank == 0:
        logger.info("Pre-training of experiment: {} is done.".format(args.experiment_name))
        writer.close()

def main():
    """Parse the command line, prepare global state and launch the worker(s).

    Derived settings (``alpha``, ``weight``, ``num_machines``,
    ``world_size``) are stored on the parsed namespace, which is then handed
    to ``main_worker``. With more than one visible GPU one worker process is
    spawned per GPU; otherwise the worker runs in the current process.

    Raises:
        RuntimeError: If no CUDA device is visible. ``main_worker``
            unconditionally selects a CUDA device, so failing here gives a
            clear message before any data or model is built.
    """
    args = default_argument_parser().parse_args()
    args.alpha = 1. / (args.num_classes - 1)
    args.weight = None
    if args.mode.startswith(('fix', 'init')):
        # Pre-computed prototype matrix, stored as <num_classes>x2048.
        path = './prototypes/weight1' if args.last_relu else './prototypes/weight0'
        args.weight = torch.Tensor(np.load(path + '_%dx%d.npy' % (args.num_classes, 2048)))

    # setup randomization
    if args.seed is not None:
        random.seed(args.seed)
        # NumPy only accepts seeds in [0, 2**32 - 1]; wrap so that any seed
        # accepted by the other generators is valid here as well.
        np.random.seed(args.seed % (2 ** 32))
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        # NOTE(review): these calls seed the launching process only. Workers
        # started through mp.spawn run in fresh interpreters and presumably
        # need to be seeded inside main_worker as well -- confirm.
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    # multi-processing
    if args.num_machines is None:
        args.num_machines = int(os.getenv('RLAUNCH_REPLICA_TOTAL', '1'))

    print('Total number of using machines: {}'.format(args.num_machines))

    if args.debug:
        # Debugging runs on a single GPU; the restriction must be in place
        # before the device count is queried below.
        args.world_size = 1
        os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    ngpus_per_node = torch.cuda.device_count()
    if ngpus_per_node < 1:
        raise RuntimeError('No CUDA device is visible; training requires at least one GPU.')

    if ngpus_per_node > 1:
        args.world_size = ngpus_per_node * args.num_machines
        # mp.spawn passes the process index as the first argument (gpu).
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        args.world_size = 1
        main_worker(0, ngpus_per_node, args)


# Script entry point: start training only when this file is executed
# directly, so that importing the module has no side effects.
if __name__ == "__main__":
    main()
